#include "WholeBodyDynamicCalculator.h"

#include <MMMSimoxTools/MMMSimoxTools.h>
#include <VirtualRobot/RobotNodeSet.h>
#include <VirtualRobot/IK/DifferentialIK.h>
#include <MMM/Motion/Plugin/KinematicPlugin/KinematicSensor.h>
#include <MMM/Motion/Plugin/ModelPosePlugin/ModelPoseSensor.h>

using namespace MMM;

/**
 * Computes the whole-body dynamic sensor for @p motion and attaches it to that motion.
 * Any exception raised while computing the sensor propagates before the motion is modified.
 */
void WholeBodyDynamicCalculator::calculate(MotionPtr motion) {
    motion->addSensor(calculateWholeBodyDynamicSensor(motion));
}

/**
 * Builds a WholeBodyDynamicSensor (center of mass + angular momentum) for a motion.
 *
 * For every timestep of the motion's ModelPoseSensor that also has a derived kinematic
 * measurement, the Simox model is posed once. The whole-body center of mass and per-segment
 * data (segment CoM, rotation, Jacobian) are captured at that pose. The angular momentum of
 * each frame is then computed from finite-difference velocities over the neighbouring *valid*
 * frames: forward difference at the start, backward at the end, central elsewhere. Each
 * difference is divided by the actual time between those frames, so skipped or unevenly
 * spaced frames are handled correctly.
 *
 * @param motion Motion that must contain a ModelPoseSensor and kinematic data.
 * @return A new sensor with one measurement per frame that had kinematic data. The sensor is
 *         NOT added to @p motion.
 * @throws Exception::MMMException if the motion already contains a WholeBodyDynamicSensor,
 *         has no ModelPoseSensor or kinematic sensor, or yields fewer than two usable frames
 *         (velocities cannot be estimated then).
 */
WholeBodyDynamicSensorPtr WholeBodyDynamicCalculator::calculateWholeBodyDynamicSensor(MotionPtr motion) {
    if (motion->getSensorByType(WholeBodyDynamicSensor::TYPE))
        throw Exception::MMMException("WholeBodyDynamicSensor already in motion " + motion->getName());

    ModelPoseSensorPtr modelPoseSensor = motion->getSensorByType<ModelPoseSensor>(ModelPoseSensor::TYPE);
    if (!modelPoseSensor) throw Exception::MMMException("No model sensor in motion " + motion->getName() + " found. WholeBodyDynamicCalculator will skip this motion.");
    std::vector<float> timesteps = modelPoseSensor->getTimesteps();

    KinematicSensorPtr kinematicSensor = KinematicSensor::join(motion->getSensorsByType<KinematicSensor>(KinematicSensor::TYPE), timesteps);
    if (!kinematicSensor) throw Exception::MMMException("No kinematic sensor in motion " + motion->getName() + " found. WholeBodyDynamicCalculator will skip this motion.");

    WholeBodyDynamicSensorPtr wholeBodyDynamicSensor(new WholeBodyDynamicSensor());

    // Create Simox objects
    VirtualRobot::RobotPtr robot = SimoxTools::buildModel(motion->getModel(), false);
    VirtualRobot::RobotNodeSetPtr robotNodeSet = VirtualRobot::RobotNodeSet::createRobotNodeSet(robot, "MMMWholeBodyDynamicCalculatorRobotNodeSet",
                                                                                                kinematicSensor->getJointNames(), "", "", true);

    // Determine the segments used for the angular momentum calculation (only those with mass).
    // Mass and inertia do not depend on the pose, so they are cached once here.
    std::vector<VirtualRobot::RobotNodePtr> segmentsWithMass;
    std::vector<float> segmentMasses;
    std::vector<Eigen::Matrix3f> segmentInertias;  // scaled by 1e6 (m^2 -> mm^2)
    for (const VirtualRobot::RobotNodePtr& node : robot->getRobotNodes())
    {
        const float mass = node->getMass();

        MMM_INFO << "Segment " << node->getName() << ": Mass = " << mass << "kg" << std::endl;

        if (mass > 0)
        {
            segmentsWithMass.push_back(node);
            segmentMasses.push_back(mass);
            segmentInertias.push_back(node->getInertiaMatrix() * 1000000.0f);  // m^2 -> mm^2
        }
    }

    // Everything needed later from a single frame, captured while the robot is posed at it.
    struct FrameData
    {
        float timestep;
        Eigen::Matrix3f rootOrientation;
        Eigen::Vector3f centerOfMass;
        Eigen::VectorXf jointAngles;
        std::vector<FrameSegmentData> segments;  // same order as segmentsWithMass
    };

    // Pre-compute segment center of mass, rotation and Jacobian for all usable timesteps.
    // Frames without kinematic data are dropped here, so every later index refers to a
    // valid frame and the per-frame data stays aligned.
    MMM_INFO << "Pre-computing data for all segments in all frames..." << std::endl;
    VirtualRobot::DifferentialIKPtr diffIK(new VirtualRobot::DifferentialIK(robotNodeSet));
    std::vector<FrameData> frames;
    frames.reserve(timesteps.size());

    for (float timestep : timesteps)
    {
        KinematicSensorMeasurementPtr kinematicSensorMeasurement = kinematicSensor->getDerivedMeasurement(timestep);
        if (!kinematicSensorMeasurement) continue;
        ModelPoseSensorMeasurementPtr modelPoseSensorMeasurement = modelPoseSensor->getDerivedMeasurement(timestep);
        if (!modelPoseSensorMeasurement) continue;

        // Set pose and joint values
        robot->setGlobalPose(modelPoseSensorMeasurement->getRootPose());
        robotNodeSet->setJointValues(kinematicSensorMeasurement->getJointAngles());

        frames.emplace_back();
        FrameData& frame = frames.back();
        frame.timestep = timestep;
        frame.rootOrientation = robot->getGlobalPose().block(0, 0, 3, 3);
        frame.centerOfMass = robot->getCoMGlobal();
        frame.jointAngles = kinematicSensorMeasurement->getJointAngles();
        frame.segments.reserve(segmentsWithMass.size());

        for (const VirtualRobot::RobotNodePtr& segment : segmentsWithMass)
        {
            FrameSegmentData fsd;
            fsd.position = segment->getCoMGlobal();
            fsd.rotation = segment->getGlobalPose().block(0, 0, 3, 3);
            fsd.jacobian = diffIK->getJacobianMatrix(segment);
            frame.segments.push_back(fsd);
        }
    }

    const std::size_t numFrames = frames.size();
    if (numFrames < 2)
        throw Exception::MMMException("Motion " + motion->getName() + " has fewer than two frames with kinematic data. WholeBodyDynamicCalculator will skip this motion.");

    for (std::size_t i = 0; i < numFrames; ++i)
    {
        // Neighbours for the finite difference: forward at the start, backward at the end,
        // central otherwise. dt is the real time between them (timesteps need not be uniform).
        const FrameData& prev = frames[i == 0 ? 0 : i - 1];
        const FrameData& next = frames[i + 1 == numFrames ? i : i + 1];
        const FrameData& cur = frames[i];
        const float dt = next.timestep - prev.timestep;

        const Eigen::Vector3f CoMVelocity = (next.centerOfMass - prev.centerOfMass) / dt;
        const Eigen::VectorXf jointVelocity = (next.jointAngles - prev.jointAngles) / dt;

        // TODO: Is this the same as the root rotation?
        const Eigen::Matrix3f robotOrientationInverse = cur.rootOrientation.inverse();

        Eigen::Vector3f angularMomentum = Eigen::Vector3f::Zero();

        for (std::size_t s = 0; s < segmentsWithMass.size(); ++s)
        {
            const FrameSegmentData& fsd = cur.segments[s];

            // Linear velocity of the segment CoM relative to the whole-body CoM
            const Eigen::Vector3f segmentVelocity = (next.segments[s].position - prev.segments[s].position) / dt;
            const Eigen::Vector3f linearVelocity = segmentVelocity - CoMVelocity;

            // Angular velocity: rotational rows (3..5) of the segment Jacobian times joint velocity
            const Eigen::Vector3f angularVelocity = (fsd.jacobian * jointVelocity).segment<3>(3);

            // Orbital term (rotated by the inverse root orientation) plus spin term.
            // NOTE(review): the spin term is not rotated; confirm that the inertia matrix and the
            // Jacobian's angular velocity are expressed in the same frame as the orbital term.
            const Eigen::Vector3f segmentAM = robotOrientationInverse * (fsd.position - cur.centerOfMass).cross(segmentMasses[s] * linearVelocity)
                                              + segmentInertias[s] * angularVelocity;
            angularMomentum += segmentAM;
        }

        angularMomentum /= 1000000.0f;  // mm^2 -> m^2

        WholeBodyDynamicSensorMeasurementPtr wholeBodyDynamicSensorMeasurement(new WholeBodyDynamicSensorMeasurement(cur.timestep, cur.centerOfMass, angularMomentum));
        wholeBodyDynamicSensor->addSensorMeasurement(wholeBodyDynamicSensorMeasurement);
    }

    return wholeBodyDynamicSensor;
}
